import torch; torch.manual_seed(42)
import torch.nn as nn
import torch.nn.functional as F
import torch.utils
import torch.distributions
import numpy as np
from transformers import GPT2LMHeadModel,  GPT2Tokenizer, GPT2Config, GPT2LMHeadModel
device = 'cuda' if torch.cuda.is_available() else 'cpu'
import os
import copy
import numpy as np
from joblib import dump, load
import sys

import torch
import random

from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    TrainingArguments,
    pipeline,
    logging,
)
from peft import LoraConfig, PeftModel
from trl import SFTTrainer
import pandas as pd
from datasets import Dataset
from transformers.trainer_callback import EarlyStoppingCallback
import pickle
from transformers import TrainerCallback
from sklearn.decomposition import PCA
from joblib import dump, load

# ---------------------------------------------------------------------------
# Compression pipeline settings (PCA followed by a multi-head VAE).
# ---------------------------------------------------------------------------
n_components = 400  # PCA components kept per stacked block
if_pca = True  # run PCA before the VAE (otherwise the VAE sees raw 4096-wide rows)
# Intended switch: reuse PCA models saved in pca_folder instead of fitting them.
# NOTE(review): `def load_pca(...)` later in this file rebinds this name to a
# function, so the __main__ check `if load_pca:` sees a truthy function object,
# not this value -- rename one of the two.
load_pca = False
sched = False  # linearly decay the KL weight `alpha` over the training epochs

n_stacked_pca = 4  # consecutive layer deltas grouped into one PCA block
n_stacked_vae = 128  # consecutive layer deltas grouped into one VAE input chunk
delta_layer_shape = (1, 2, 4096)
input_shape_pca = (1, 2 * n_stacked_pca, 4096)

# Each layer delta contributes 2 rows; the row width depends on whether PCA ran.
input_shape_vae = (1, 2 * n_stacked_vae, n_components if if_pca else 4096)

full_delta_shape = (128, 2, 4096)
input_size = int(np.prod(input_shape_vae))  # number of elements per VAE input chunk
output_shape = input_shape_vae

latent_dims = 8
hidden_dims = (512, 512)
n_inputs = 128 // n_stacked_vae  # number of VAE heads (the script assumes 128 layers)
epochs = 600
lr = 1e-4
alpha = 0.03  # weight of the KL term in the VAE loss


print("PCA:", "components", n_components, "stacked", n_stacked_pca)
print("VAE", "latent_dims", latent_dims, "hidden_dims", hidden_dims, "stacked", n_stacked_vae)
print("alpha", alpha)

pca_folder = "MultiheadPCA_reddit"  # directory holding per-block PCA .joblib files
vae_folder = "llama_vae_state_reddit.pt"  # despite the name, a single file path passed to torch.save

def pca_preprocess(raw_delta):
    """Project a raw per-layer delta into per-block PCA coefficient space.

    The layers are grouped into blocks of ``n_stacked_pca``. Each block is
    projected with its own fitted PCA, and the result is split back into a
    list of 2-row per-layer pieces.

    Relies on the module-level ``pca_by_layer``, which is only assigned in
    the ``__main__`` section.
    """
    grouped = stack(raw_delta, n_stacked_pca)
    coefficients = apply_pca_by_layer(grouped.cpu().detach(), pca_by_layer)
    return unstack(coefficients, n_stacked_pca)

def vae_preprocess(delta_pca):
    """Group per-layer pieces into VAE input chunks of ``n_stacked_vae`` layers."""
    return stack(delta_pca, n_stacked_vae)

def preprocess(raw_delta):
    """Full input pipeline: PCA projection followed by VAE chunking."""
    pca_pieces = pca_preprocess(raw_delta)
    return vae_preprocess(pca_pieces)

def vae_postprocess(delta_vae_out):
    """Split VAE output chunks back into a list of 2-row per-layer pieces."""
    return unstack(delta_vae_out, n_stacked_vae)

def pca_postprocess(delta_vae_unstacked):
    """Map per-layer PCA coefficients back to full-width per-layer deltas.

    The pieces are regrouped into ``n_stacked_pca`` blocks and passed through
    each block's inverse PCA. The result is split back per layer and stacked
    along a new leading axis.

    Relies on the module-level ``pca_by_layer``, which is only assigned in
    the ``__main__`` section.
    """
    regrouped = stack(delta_vae_unstacked, n_stacked_pca)
    reconstructed = inverse_pca_by_layer(regrouped, pca_by_layer)
    per_layer = unstack(reconstructed, n_stacked_pca)
    return torch.stack(per_layer, dim=0)

def postprocess(delta_vae_out):
    """Full output pipeline: undo VAE chunking, then invert the PCA projection."""
    pca_pieces = vae_postprocess(delta_vae_out)
    return pca_postprocess(pca_pieces)

# Choose the pre/post-processing pair: with PCA the VAE operates on PCA
# coefficients, otherwise it operates directly on the raw delta rows.
preprocess_fn, postprocess_fn = (
    (preprocess, postprocess) if if_pca else (vae_preprocess, vae_postprocess)
)

# Model identifier passed to from_pretrained for both the model and tokenizer.
# NOTE(review): a bare name like this presumably resolves to a local directory
# rather than a Hub repo id -- confirm.
model_name = "Llama-2-7b-hf"

################################################################################
# QLoRA parameters
################################################################################

# LoRA attention dimension (rank r)
lora_r = 2

# Alpha parameter for LoRA scaling
lora_alpha = 8

# Dropout probability for LoRA layers
lora_dropout = 0.1

################################################################################
# bitsandbytes parameters
################################################################################

# Load the base model with 4-bit quantization
use_4bit = True

# Compute dtype for 4-bit base models (resolved with getattr(torch, ...))
bnb_4bit_compute_dtype = "float16"

# Quantization type (fp4 or nf4)
bnb_4bit_quant_type = "nf4"

# Enable nested (double) quantization for 4-bit base models
use_nested_quant = False

################################################################################
# TrainingArguments parameters
################################################################################

# Output directory where the model predictions and checkpoints will be stored
output_dir = "./results"

# Number of training epochs
num_train_epochs = 1

# Enable fp16/bf16 training (set bf16 to True with an A100)
fp16 = False
bf16 = False

# Batch size per GPU for training
per_device_train_batch_size = 1

# Batch size per GPU for evaluation
per_device_eval_batch_size = 1

# Number of update steps over which gradients are accumulated
gradient_accumulation_steps = 1

# Enable gradient checkpointing
gradient_checkpointing = True

# Maximum gradient norm (gradient clipping)
max_grad_norm = 0.3

# Initial learning rate (AdamW optimizer)
learning_rate = 2e-4

# Weight decay to apply to all layers except bias/LayerNorm weights
weight_decay = 0.001

# Optimizer to use
optim = "paged_adamw_32bit"

# Learning rate schedule
lr_scheduler_type = "cosine"

# Number of training steps (overrides num_train_epochs when positive)
max_steps = -1

# Ratio of steps for a linear warmup (from 0 to learning rate)
warmup_ratio = 0.03

# Group sequences into batches of similar length
# Saves memory and speeds up training considerably
group_by_length = True

# Save a checkpoint every X update steps
save_steps = 0

# Log every X update steps
logging_steps = 25

################################################################################
# SFT parameters
################################################################################

# Maximum sequence length to use
max_seq_length = None

# Pack multiple short examples into the same input sequence to increase efficiency
packing = False

# Load the entire model onto GPU 0
device_map = {"": 0}

def load_pca(pca_folder):
    """Load every per-block PCA model stored in ``pca_folder``.

    Files are ordered by the integer before the first ``'_'`` in their name
    (e.g. ``'12_pca.joblib'`` -> 12). This matches the ``<i>_pca.joblib``
    names written by the ``__main__`` section. Each file is read with
    joblib's ``load``.

    NOTE(review): this definition rebinds the module-level boolean flag also
    named ``load_pca``.
    """
    def _block_index(file_name):
        return int(file_name[:file_name.find('_')])

    ordered_names = sorted(os.listdir(pca_folder), key=_block_index)
    return [load(os.path.join(pca_folder, name)) for name in ordered_names]


def apply_delta(model, layer_names, delta):
    """Overwrite the named model tensors in place with the given delta.

    ``delta[i]`` is copied into the state-dict entry named ``layer_names[i]``
    with autograd disabled. The copy goes into the tensors returned by
    ``state_dict()``, which share storage with the model's parameters.

    Args:
        model: Module whose tensors are overwritten.
        layer_names: Iterable of state-dict keys, aligned with ``delta``.
        delta: Indexable sequence of tensors; ``delta[i]`` must be
            copy-compatible with the tensor named ``layer_names[i]``.

    Returns:
        The same ``model`` object, mutated in place (not a copy).
    """
    with torch.no_grad():
        # Build the name -> tensor map once. state_dict() walks every module,
        # so calling it per layer made this loop O(layers * parameters).
        state = model.state_dict()
        for i, name in enumerate(layer_names):
            state[name].copy_(delta[i])
    return model

class Encoder(nn.Module):
    """Single encoder head: flatten the entire input, then Linear + ReLU.

    Only ``hidden_dims[0]`` is used as the output width. ``latent_dims`` is
    accepted for signature parity with the other modules but is unused.
    ``input_size`` must equal the total number of elements in one input.
    """

    def __init__(self, input_size, latent_dims, hidden_dims):
        super().__init__()
        width = hidden_dims[0]
        self.linear1 = nn.Linear(input_size, width)

    def forward(self, x):
        # Every dimension (including any leading size-1 dim) is collapsed into one vector.
        flat = x.flatten()
        return F.relu(self.linear1(flat))

class MultiheadVariationalEncoder(nn.Module):
    """Variational encoder with one ``Encoder`` head per input chunk.

    Chunk ``i`` (``x[i]``) is encoded by ``self.encoders[i]``. The head
    outputs are concatenated into one vector, which is projected to the mean
    (``linear2``) and log standard deviation (``linear3``) of a diagonal
    Gaussian. A latent sample is drawn with the reparameterization trick.

    Args:
        input_size: Number of elements in one input chunk.
        latent_dims: Size of the latent vector.
        hidden_dims: Hidden sizes; only ``hidden_dims[0]`` is used here.
        n_inputs: Number of chunks / encoder heads.

    After each forward pass, ``self.kl`` holds the KL regularization term.
    """

    def __init__(self, input_size, latent_dims, hidden_dims, n_inputs):
        super(MultiheadVariationalEncoder, self).__init__()
        print("New Encoder")
        # Previously the heads were always pinned to the GPU here. Only do so
        # when CUDA exists, so the module can also be built on CPU-only hosts.
        use_cuda = torch.cuda.is_available()
        heads = [Encoder(input_size, latent_dims, hidden_dims) for _ in range(n_inputs)]
        if use_cuda:
            heads = [head.cuda() for head in heads]
        self.encoders = nn.ModuleList(heads)
        # Unused by forward(); kept so previously saved state_dicts still load.
        self.linear1 = nn.Linear(input_size, hidden_dims[0])
        self.linear2 = nn.Linear(hidden_dims[0]*n_inputs, latent_dims)  # -> mu
        self.linear3 = nn.Linear(hidden_dims[0]*n_inputs, latent_dims)  # -> log(sigma)

        self.N = torch.distributions.Normal(0, 1)
        if use_cuda:
            # Hack to sample directly on the GPU: Normal's loc/scale are not
            # module buffers, so .to(device) on the module would not move them.
            self.N.loc = self.N.loc.cuda()
            self.N.scale = self.N.scale.cuda()
        self.kl = 0

    def forward(self, x):
        head_outputs = [self.encoders[i](x[i]) for i in range(x.shape[0])]
        features = torch.flatten(torch.stack(head_outputs, dim=0))
        mu = self.linear2(features)
        # linear3 is evaluated once. Its output is log(sigma); using it directly
        # (instead of torch.log(torch.exp(.))) avoids -inf when exp underflows.
        log_sigma = self.linear3(features)
        sigma = torch.exp(log_sigma)
        z = mu + sigma*self.N.sample(mu.shape)
        # NOTE(review): the standard KL(N(mu, sigma^2) || N(0, 1)) is
        # sum(0.5*sigma^2 + 0.5*mu^2 - log(sigma) - 0.5). This variant doubles
        # the quadratic terms (its minimum is at sigma = 1/sqrt(2)). It is kept
        # as-is because `alpha` was tuned against it.
        self.kl = (sigma**2 + mu**2 - log_sigma - 1/2).sum()
        return z

class Decoder(nn.Module):
    """Single decoder head: hidden vector -> reshaped reconstruction.

    Args:
        output_size: Number of output elements; must equal the product of
            ``output_shape``, since the final Linear output is reshaped to it.
        output_shape: Shape of the returned tensor.
        latent_dims: Unused; kept for signature parity.
        hidden_dims: ``(h0, h1)`` hidden sizes; the input must have ``h0`` features.
        n_inputs: Unused; kept for signature parity.
    """

    def __init__(self, output_size, output_shape, latent_dims, hidden_dims, n_inputs):
        super(Decoder, self).__init__()
        # Use the constructor argument. forward() previously read the
        # module-level global `output_shape` and ignored this parameter.
        self.output_shape = output_shape
        self.linear1 = nn.Linear(hidden_dims[0], hidden_dims[0])
        self.linear2 = nn.Linear(hidden_dims[0], hidden_dims[1])
        self.linear3 = nn.Linear(hidden_dims[1], output_size)

    def forward(self, z):
        z = F.relu(self.linear1(z))
        # NOTE(review): there is no nonlinearity between linear2 and linear3,
        # so together they act as a single linear map -- confirm this is intended.
        z = self.linear2(z)
        z = self.linear3(z)
        return z.reshape(self.output_shape)

class MultiheadDecoder(nn.Module):
    """Decoder that expands one latent vector into ``n_inputs`` reconstructions.

    The latent vector is projected to ``hidden_dims[0] * n_inputs`` features
    and split into ``n_inputs`` equal slices. Each slice is decoded by its own
    ``Decoder``. The results are stacked, giving shape
    ``(n_inputs, *output_shape)``.

    NOTE(review): ``torch.split`` runs along dim 0, so ``z`` is assumed to be a
    single unbatched 1-D latent vector.
    """

    def __init__(self, output_size, output_shape, latent_dims, hidden_dims, n_inputs):
        super(MultiheadDecoder, self).__init__()
        print("New Decoder")
        # Remember the split geometry so forward() no longer depends on the
        # module-level globals `hidden_dims` / `n_inputs`.
        self.hidden_size = hidden_dims[0]
        self.n_inputs = n_inputs
        self.linear1 = nn.Linear(latent_dims, hidden_dims[0]*n_inputs)
        heads = [Decoder(output_size, output_shape, latent_dims, hidden_dims, n_inputs) for _ in range(n_inputs)]
        # Only pin to the GPU when CUDA exists (previously unconditional).
        if torch.cuda.is_available():
            heads = [head.cuda() for head in heads]
        self.decoders = nn.ModuleList(heads)

    def forward(self, z):
        z = F.relu(self.linear1(z))
        z_split = torch.split(z, self.hidden_size)
        assert len(z_split) == self.n_inputs

        out = [self.decoders[i](z_split[i]) for i in range(self.n_inputs)]
        return torch.stack(out, dim=0)

class MultiheadVariationalAutoencoder(nn.Module):
    """Multi-head VAE: ``MultiheadVariationalEncoder`` followed by ``MultiheadDecoder``.

    ``forward`` returns the decoder's stacked reconstruction. The KL term of
    the last forward pass is available as ``self.encoder.kl``.
    """

    def __init__(self, input_size, output_shape, latent_dims, hidden_dims, n_inputs):
        super().__init__()
        # The encoder is built before the decoder so parameter initialization
        # consumes the RNG in the same order as before.
        self.encoder = MultiheadVariationalEncoder(input_size, latent_dims, hidden_dims, n_inputs)
        self.decoder = MultiheadDecoder(input_size, output_shape, latent_dims, hidden_dims, n_inputs)

    def forward(self, x):
        return self.decoder(self.encoder(x))

     
def train(autoencoder, data, opt, eval_delta_pca, epochs=25):
    """Train the VAE with a summed squared-error reconstruction loss plus a weighted KL term.

    Each epoch:
    - Runs one optimizer step per element of ``data``, moving each element
      to the module-level ``device``.
    - Evaluates on ``eval_delta_pca``.
    - Prints the mean training loss and the eval loss.

    The KL weight is the module-level ``alpha``. When the module-level
    ``sched`` flag is set, the weight decays linearly as
    ``alpha * (1 - epoch / epochs)``.

    Returns:
        The trained ``autoencoder`` (same object).
    """
    for epoch in range(epochs):
        print('Epoch:', epoch)
        autoencoder.train()
        losses = []
        if sched:
            curr_alpha = alpha - (alpha/epochs) * epoch
        else:
            curr_alpha = alpha
        print("Curr alpha:", curr_alpha)
        for x in data:
            x = x.to(device)
            autoencoder.zero_grad()
            x_hat = autoencoder(x)
            loss = ((x-x_hat)**2).sum() + curr_alpha*autoencoder.encoder.kl
            losses.append(loss.item())
            loss.backward()
            opt.step()
        autoencoder.eval()
        # Evaluation needs no gradients. Without no_grad, the eval forward
        # built an autograd graph that stayed alive until the next epoch.
        with torch.no_grad():
            eval_hat = autoencoder(eval_delta_pca)
            eval_loss = ((eval_delta_pca-eval_hat)**2).sum() + curr_alpha*autoencoder.encoder.kl

        # Guard against empty `data` (previously a ZeroDivisionError).
        train_loss = sum(losses)/len(losses) if losses else float('nan')
        print("Train loss:", train_loss, " Eval loss:", eval_loss.item())
    return autoencoder


def generate_examples(model, prompt = " ", n_examples=1):
    """Sample ``n_examples`` continuations of ``prompt`` from ``model`` and print each.

    Uses a text-generation pipeline with top-k=50 / top-p=0.95 sampling and
    ``max_length=100``. The tokenizer is the module-level ``tokenizer``, which
    is only assigned in the ``__main__`` section.

    ``n_examples`` was previously ignored; it is now forwarded as
    ``num_return_sequences`` (the default of 1 keeps the old behavior).
    """
    pipe = pipeline(task="text-generation", model=model, tokenizer=tokenizer, max_length=100, do_sample=True, top_k=50, top_p=0.95)
    results = pipe(str(prompt), num_return_sequences=n_examples)
    for result in results:
        print(result['generated_text'])

def cross_entropy_evaluation(model, tokenizer, test_list, bos="", eos=""):
    """Average next-token cross-entropy of ``model`` over ``test_list``.

    Each text is wrapped as ``bos + text + eos``, tokenized, and scored one
    sequence at a time (batch size 1). The result is the mean over texts of
    each text's token-averaged loss, so it is not weighted by token count.

    Args:
        model: Causal LM returning an object with ``.logits`` of shape
            ``(1, seq_len, vocab)``.
        tokenizer: Callable returning ``'input_ids'`` and ``'attention_mask'`` lists.
        test_list: Non-empty list of texts.

    Returns:
        The average loss as a Python float.
    """
    model.eval()
    tokenized = tokenizer([bos + txt + eos for txt in test_list])
    total_eval_loss = 0
    # Pure inference: disable autograd so no graph is kept for the
    # (potentially very large) model's activations.
    with torch.no_grad():
        for ids, mask in zip(tokenized['input_ids'], tokenized['attention_mask']):
            b_input_ids = torch.tensor(ids).unsqueeze(0)
            b_masks = torch.tensor(mask).unsqueeze(0)

            outputs = model(b_input_ids,
                            attention_mask=b_masks,
                            labels=b_input_ids)

            # Shift so position t predicts token t+1.
            logits = outputs.logits[:, :-1, :]
            labels = b_input_ids[:, 1:].contiguous()
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), labels.view(-1))
            total_eval_loss += loss.item()
    return total_eval_loss / len(test_list)

def compare_models(model, tokenizer, layer_names, delta1, delta2, data_list):
    """Score ``model`` as-is, then with ``delta2``, then with ``delta1``.

    ``apply_delta`` mutates the model in place, so all three evaluations run
    sequentially on the same object. On return, the model holds ``delta1``.

    Returns:
        ``(baseline_loss, loss_with_delta2, loss_with_delta1)``. Note the
        order: the second element corresponds to ``delta2``.
    """
    scores = [cross_entropy_evaluation(model, tokenizer, data_list)]
    for patch in (delta2, delta1):
        patched = apply_delta(model, layer_names, patch)
        scores.append(cross_entropy_evaluation(patched, tokenizer, data_list))
    return tuple(scores)

def perform_pca(delta_list, n_components, other_dim, n_filters):
    """Fit one PCA across all deltas and project every delta into it.

    Each delta is viewed as ``n_filters * other_dim`` rows, and all rows from
    all deltas form the training matrix. Each delta is then transformed
    individually and reshaped to ``(n_filters, other_dim, n_components)``.

    Returns:
        ``(pca, projected)``: the fitted ``PCA`` and a list of numpy arrays,
        one per input delta.
    """
    rows_per_delta = n_filters * other_dim
    stacked = np.array([d.cpu().detach().numpy() for d in delta_list])
    training_rows = stacked.reshape(len(delta_list) * rows_per_delta, -1)
    print(training_rows.shape)
    pca = PCA(n_components=n_components, svd_solver='auto')
    pca.fit(training_rows)
    projected = [
        pca.transform(d.reshape(rows_per_delta, -1).detach().cpu()).reshape(n_filters, other_dim, -1)
        for d in delta_list
    ]
    return pca, projected

def inverse_pca_by_layer(pca_delta, pca_by_layer):
    """Map per-block PCA coefficients back to the original feature space.

    Block ``i`` (``pca_delta[i]``, 2-D: rows x components) is passed through
    ``pca_by_layer[i].inverse_transform``. The reconstructions are stacked
    into a float32 tensor of shape ``(n_blocks, rows, n_features)``.

    The row count is taken from each input block. Previously it was read from
    the global ``data_pca``, which exists only when run as ``__main__``.
    """
    restored = []
    for i, pca in enumerate(pca_by_layer):
        coeffs = pca_delta[i].cpu().detach()
        rows = coeffs.shape[0]
        recon = pca.inverse_transform(coeffs)
        restored.append(torch.tensor(recon.reshape((rows, -1))))
    return torch.stack([layer.float() for layer in restored], dim=0)

def apply_pca_by_layer(delta, pca_by_layer):
    """Project each stacked block with its own fitted PCA.

    Block ``i`` (``delta[i]``, expected 2-D rows x features as produced by
    ``stack``) is reshaped to ``(rows, -1)`` and passed to
    ``pca_by_layer[i].transform``. The results are stacked into a float32
    tensor ``(n_blocks, rows, n_components)``.

    ``rows`` is taken from each block. Previously it was read from the global
    ``data_pca``, which exists only when run as ``__main__``. ``delta`` must
    be on the CPU because it is handed to the PCA object directly.
    """
    projected = []
    for i, pca in enumerate(pca_by_layer):
        block = delta[i]
        rows = block.shape[0]
        coeffs = pca.transform(block.reshape(rows, -1)).reshape(rows, -1)
        projected.append(torch.tensor(coeffs).float())
    return torch.stack(projected, dim=0)

def stack(delta, n_stacked):
    """Group consecutive layer deltas into blocks of ``n_stacked``.

    The ``n_stacked`` entries of each group are concatenated row-wise
    (``torch.vstack``), and the groups are stacked along a new leading axis.
    For ``(L, 2, D)`` input this gives ``(L // n_stacked, 2 * n_stacked, D)``.

    ``delta`` may be a tensor or a list of tensors. The number of layers is
    taken from ``len(delta)`` (previously hard-coded to 128, which broke for
    any other length). ``len(delta)`` should be a multiple of ``n_stacked``;
    otherwise the last group is smaller and ``torch.stack`` raises.
    """
    return torch.stack(
        [torch.vstack(list(delta[i:i + n_stacked])) for i in range(0, len(delta), n_stacked)],
        dim=0,
    )

def unstack(delta_stacked, n_stacked):
    """Inverse of ``stack``: split each block back into 2-row layer pieces.

    Each entry along dim 0 is squeezed (dropping size-1 dims), then cut into
    ``n_stacked`` consecutive 2-row slices.

    Returns:
        A flat list of ``len(delta_stacked) * n_stacked`` tensors.
    """
    pieces = []
    for block in delta_stacked:
        rows = block.squeeze()
        pieces.extend(rows[2 * k:2 * k + 2, :] for k in range(n_stacked))
    return pieces


if __name__ == "__main__":
    # Usage: <script> DATA_FOLDER DELTA_FOLDER
    # Every sub-folder name of DATA_FOLDER must also exist in DELTA_FOLDER:
    #   DATA_FOLDER/<name>  contains a file whose name includes 'test'
    #   DELTA_FOLDER/<name> contains a '.pkl' file holding a dict of layer deltas
    # NOTE: this section deliberately runs at module scope. pca_preprocess,
    # pca_postprocess and generate_examples read the globals it assigns
    # (pca_by_layer, tokenizer), so do not wrap it in a function.
    data_folder = sys.argv[1]
    delta_folder = sys.argv[2]

    delta_list = []
    test_data_list = []
    folder_name_list = []
    layer_names = None
    og_shapes = None
    for folder_name in os.listdir(data_folder):
        curr_data_folder = os.path.join(data_folder, folder_name)
        curr_delta_folder = os.path.join(delta_folder, folder_name)
        folder_name_list.append(folder_name)
        delta_file = [file for file in os.listdir(curr_delta_folder) if '.pkl' in file][0]
        data_file = [file for file in os.listdir(curr_data_folder) if 'test' in file][0]
        print(folder_name)
        # SECURITY: pickle.load can execute arbitrary code -- only load trusted delta files.
        with open(os.path.join(curr_delta_folder, delta_file), 'rb') as delta_fh:
            d = pickle.load(delta_fh)

        if layer_names is None:
            layer_names = d.keys()
        else:
            # dict_keys equality is set-like (order-insensitive), so this only
            # checks that the same layers are present, not their order.
            assert(layer_names == d.keys())
        # Read the tensors in the first file's key order, so every delta lines
        # up with layer_names even if a pickle lists its keys in another order.
        loaded_tensor_list = [d[name] for name in layer_names]

        if og_shapes is None:
            og_shapes = [t.shape for t in loaded_tensor_list]

        delta = torch.stack([t.reshape(delta_layer_shape[1:]) for t in loaded_tensor_list], axis=0)

        with open(os.path.join(curr_data_folder, data_file), 'r') as test_fh:
            # Examples are separated by blank lines; the final split piece is dropped.
            test = test_fh.read().split('\n\n')[:-1]
        test_data_list.append(test)
        delta_list.append(delta)
    raw_data = torch.stack(delta_list, axis=0)
    print('Raw data shape', raw_data.shape)

    # The first listed folder is held out for validation; the rest are used for training.
    eval_delta = raw_data[0]
    eval_folder_name = folder_name_list[0]
    print("Validation is", eval_folder_name)
    folder_name_list = folder_name_list[1:]
    raw_data = raw_data[1:]
    eval_test_data = test_data_list[0]
    test_data_list = test_data_list[1:]

    processed_data = raw_data

    # PCA Stack
    if if_pca:
        delta_pca_stacked = []
        for delta in raw_data:
            delta_pca_stacked.append(stack(delta, n_stacked_pca))
        data_pca = torch.stack(delta_pca_stacked, axis=0)

        print("Data PCA shape", data_pca.shape)

        # NOTE(review): `load_pca` here resolves to the function defined above,
        # which shadows the boolean config flag of the same name. This branch
        # is therefore always taken, whatever the flag says.
        if load_pca:
            pca_by_layer = load_pca(pca_folder)
            pca_deltas_recollected = []
            for delta in data_pca:
                pca_deltas_recollected.append(apply_pca_by_layer(delta.cpu().detach(), pca_by_layer))
        else:
            if not os.path.exists(pca_folder):
                os.mkdir(pca_folder)

            # Regroup so each entry holds one stacked block from every training delta.
            data_by_layer = []
            for i in range(data_pca.shape[1]):
                data_by_layer.append([delta[i] for delta in data_pca])

            pca_by_layer = []
            pca_deltas_by_layer = []
            for i in range(data_pca.shape[1]):
                print(i)
                pca, pca_deltas = perform_pca(data_by_layer[i], n_components, data_pca.shape[2], n_filters=1)
                pca_by_layer.append(pca)
                pca_deltas_by_layer.append(pca_deltas)

            # Transpose back to one (n_blocks, rows, n_components) tensor per training delta.
            pca_deltas_recollected = []
            for i in range(data_pca.shape[0]):
                delta_pca = []
                for j in range(data_pca.shape[1]):
                    delta_pca.append(pca_deltas_by_layer[j][i])
                pca_deltas_recollected.append(torch.stack([torch.tensor(layer).float().cuda() for layer in delta_pca], axis=0))

            for i in range(len(pca_by_layer)):
                dump(pca_by_layer[i], pca_folder+'/'+str(i)+'_pca.joblib')

        processed_data = []
        for delta_stacked in pca_deltas_recollected:
            processed_data.append(unstack(delta_stacked, n_stacked_pca))

        avg_explained = []
        for pca in pca_by_layer:
            var_exp = np.sum(pca.explained_variance_ratio_)
            print("Variance explained:", var_exp)
            avg_explained.append(var_exp)
        print("Avg variance explained:", sum(avg_explained)/len(avg_explained))

    delta_vae_stacked = []
    for delta in processed_data:
        delta_vae_stacked.append(stack(delta, n_stacked_vae))
    data_vae = torch.stack(delta_vae_stacked, axis=0)

    print("Data VAE shape", data_vae.shape)

    eval_delta_vae_stacked = preprocess_fn(eval_delta)
    vae = MultiheadVariationalAutoencoder(input_size, output_shape, latent_dims, hidden_dims, n_inputs).to(device)

    opt = torch.optim.Adam(vae.parameters(), lr=lr)

    vae = train(vae, data_vae, opt, eval_delta_vae_stacked.cuda(), epochs=epochs)

    # Despite its name, vae_folder is a single file path.
    torch.save(vae.state_dict(), vae_folder)

    print("----------EVALUATING FOR TRAIN------------")

    # Load tokenizer and model with QLoRA configuration
    compute_dtype = getattr(torch, bnb_4bit_compute_dtype)

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=use_4bit,
        bnb_4bit_quant_type=bnb_4bit_quant_type,
        bnb_4bit_compute_dtype=compute_dtype,
        bnb_4bit_use_double_quant=use_nested_quant,
    )

    # Check GPU compatibility with bfloat16
    if compute_dtype == torch.float16 and use_4bit:
        major, _ = torch.cuda.get_device_capability()
        if major >= 8:
            print("=" * 80)
            print("Your GPU supports bfloat16: accelerate training with bf16=True")
            print("=" * 80)

    # Load base model
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=bnb_config,
        device_map=device_map
    )
    model.config.use_cache = False
    model.config.pretraining_tp = 1

    # Load LLaMA tokenizer (module-level global: generate_examples reads it)
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right" # Fix weird overflow issue with fp16 training

    peft_config = LoraConfig(
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        r=lora_r,
        bias="none",
        task_type="CAUSAL_LM",
    )

    training_arguments = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=num_train_epochs,
        per_device_train_batch_size=per_device_train_batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        optim=optim,
        save_steps=save_steps,
        logging_steps=logging_steps,
        learning_rate=learning_rate,
        weight_decay=weight_decay,
        fp16=fp16,
        bf16=bf16,
        max_grad_norm=max_grad_norm,
        max_steps=max_steps,
        warmup_ratio=warmup_ratio,
        group_by_length=group_by_length,
        lr_scheduler_type=lr_scheduler_type,
        evaluation_strategy="steps",  # Evaluate every `eval_steps`.
        eval_steps=0.25/num_train_epochs,  # Number of update steps between two evaluations.
        load_best_model_at_end=True,  # Load the best model found during training at the end of training
        metric_for_best_model="eval_loss",  # Use evaluation loss for early stopping
        greater_is_better=False,  # Smaller evaluation loss is better
    )

    # NOTE(review): trainer.train() is never called in this script. The
    # trainer appears to be built only so that peft_config attaches LoRA
    # adapters to `model` (whose 'lora' parameters are snapshotted below) -- confirm.
    trainer = SFTTrainer(
        model=model,
        dataset_text_field="text",
        peft_config = peft_config,
        max_seq_length=max_seq_length,
        tokenizer=tokenizer,
        args=training_arguments,
        packing=packing,
    )

    # Snapshot the initial LoRA parameters so every comparison starts from them.
    # NOTE(review): apply_delta pairs og_weights[i] with layer_names[i]; this
    # assumes named_parameters() yields LoRA params in the same order as the
    # pickle's keys -- confirm.
    og_weights = []
    for name, param in model.named_parameters():
        if 'lora' in name:
            og_weights.append(param.detach().clone())

    vae.eval()
    layer_names = list(layer_names)
    print('data shape', data_vae.shape)

    for idx in range(data_vae.shape[0]):
        print("Evalution for", folder_name_list[idx])
        to_evaluate = data_vae[idx]
        # Inference only: no autograd graph is needed for the reconstruction.
        with torch.no_grad():
            to_evaluate_out = vae(to_evaluate.cuda())
        print('to_evaluate_out', to_evaluate_out.shape)
        delta_unstacked = postprocess_fn(to_evaluate_out)
        delta_out = [delta_unstacked[i].reshape(og_shapes[i]) for i in range(128)]
        # Now the real delta
        real_delta = [raw_data[idx][i].reshape(og_shapes[i]) for i in range(128)]

        model = apply_delta(model, layer_names, og_weights)
        # Printed tuple: (baseline, loss with reconstructed delta, loss with real delta).
        print(compare_models(model, tokenizer, layer_names, real_delta, delta_out, test_data_list[idx]))
        generate_examples(apply_delta(model, layer_names, delta_out), prompt="", n_examples=1)
        print("")
        generate_examples(apply_delta(model, layer_names, real_delta), prompt=tokenizer.bos_token+'I', n_examples=1)
        print("------------------------------------")
        # Only the first 11 training deltas (idx 0..10) are evaluated.
        if idx == 10:
            break

    print("----------EVALUATING FOR EVAL------------")
    to_evaluate = eval_delta.cpu().detach()
    print("eval_delta", eval_delta.shape)
    print("Evalution for", eval_folder_name)

    with torch.no_grad():
        to_evaluate_out = vae(eval_delta_vae_stacked.cuda())
    print('to_evaluate_out', to_evaluate_out.shape)
    delta_unstacked = postprocess_fn(to_evaluate_out)
    delta_out = [delta_unstacked[i].reshape(og_shapes[i]) for i in range(128)]
    # Now the real delta
    real_delta = [to_evaluate[i].reshape(og_shapes[i]) for i in range(128)]

    model = apply_delta(model, layer_names, og_weights)
    print(compare_models(model, tokenizer, layer_names, real_delta, delta_out, eval_test_data))
    generate_examples(apply_delta(model, layer_names, delta_out), prompt="", n_examples=1)
    print("")
    generate_examples(apply_delta(model, layer_names, real_delta), prompt=tokenizer.bos_token+'I', n_examples=1)

